import oneflow
import contextlib
import oneflow.nn as nn

def split_batch(batch):
    """Split a batch into its inputs and its targets.

    Args:
        batch: Either a list/tuple whose first element is the inputs and
            whose remaining elements are the targets, or any other object,
            which is treated as inputs without targets.

    Returns:
        For a list/tuple: the tuple ``(inputs, targets)``, where ``targets``
        is the single remaining element when there is exactly one, and
        otherwise the list of remaining elements (empty when there are
        none). For anything else: the list ``[batch, None]``.

    Raises:
        ValueError: If ``batch`` is an empty list or tuple.
    """
    # Non-sequence batches carry inputs only.
    if not isinstance(batch, (list, tuple)):
        return [batch, None]

    first, *rest = batch
    # A lone target is unwrapped so callers do not get a one-element list.
    labels = rest[0] if len(rest) == 1 else rest
    return first, labels

@contextlib.contextmanager
def set_mode(model, training=True):
    """Temporarily switch ``model`` into the requested training mode.

    Args:
        model: Object exposing a ``training`` attribute and a
            ``train(mode)`` method.
        training: Mode passed to ``model.train`` for the duration of the
            ``with`` block.

    Yields:
        None. The mode recorded from ``model.training`` on entry is passed
        back to ``model.train`` on exit.
    """
    ori_mode = model.training
    model.train(training)
    try:
        yield
    finally:
        # Restore even when the with-body raises; otherwise the model would
        # stay stuck in the temporary mode.
        model.train(ori_mode)


def move_to_device(obj, device):
    """Move tensors and modules, possibly nested in lists/tuples, to ``device``.

    Args:
        obj: A ``oneflow.Tensor``, an ``nn.Module``, a list/tuple of such
            objects (arbitrarily nested), or any other object.
        device: Target device, forwarded as ``.to(device=device)``.

    Returns:
        For a tensor or module: the result of ``obj.to(device=device)``.
        For a list or tuple: a new list with every element processed by
        this function. Any other object (for example ``None``) is returned
        unchanged.
    """
    if isinstance(obj, (oneflow.Tensor, nn.Module)):
        return obj.to(device=device)
    if isinstance(obj, (list, tuple)):
        # Recurse so nested containers and non-tensor entries (e.g. None)
        # do not fail on a missing ``.to`` attribute.
        return [move_to_device(item, device) for item in obj]
    # Nothing to move: hand the object back instead of dropping it.
    return obj


def flatten_dict(dic):
    """Flatten a nested dict into a single-level dict with ``/``-joined keys.

    Each leaf value is stored under the path of keys leading to it, with
    the keys converted via ``str`` and joined by ``'/'``; for example
    ``{'a': {'b': 1}}`` becomes ``{'a/b': 1}``.

    Args:
        dic: A dict whose values may themselves be dicts, to any depth.

    Returns:
        A new flat dict. Nested dicts that are empty contribute no entry.
        Leading and trailing ``'/'`` characters of each assembled path are
        stripped, so keys that themselves start or end with ``'/'`` lose
        those characters at the path boundaries.
    """
    flattened = {}

    def _flatten(prefix, mapping):
        for key, value in mapping.items():
            # str(key) rather than '%s' % key: %-formatting treats a tuple
            # key as the argument list and fails or unwraps it.
            path = prefix + str(key) + '/'
            if isinstance(value, dict):
                _flatten(path, value)
            else:
                flattened[path.strip('/')] = value

    _flatten('', dic)
    return flattened
